/*
 * cache.c
 *
 * Copyright (C) 2016 Aleksandar Andrejevic <theflash@sdf.lonestar.org>
 *
 * This program is free software: you can redistribute it and/or modify
 * it under the terms of the GNU Affero General Public License as
 * published by the Free Software Foundation, either version 3 of the
 * License, or (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU Affero General Public License for more details.
 *
 * You should have received a copy of the GNU Affero General Public License
 * along with this program.  If not, see <http://www.gnu.org/licenses/>.
 */

#include <device.h>
#include <cache.h>
#include <heap.h>

/*
 * Writes back every dirty entry in the subtree rooted at `node` through
 * cache->write_proc, visiting the node itself before its left and right
 * children.
 *
 * An entry is marked clean only when write_proc reports ERR_SUCCESS *and*
 * the full block was written; otherwise it stays dirty so a later flush can
 * retry. This matches the rule write_cache applies to write-through writes.
 * Errors are not reported to the caller.
 */
static void flush_cache_internal(cache_descriptor_t *cache, avl_node_t *node, void *context)
{
    if (node == NULL) return;

    cache_entry_t *entry = CONTAINER_OF(node, cache_entry_t, node);
    if (entry->dirty)
    {
        /* Initialised so a write_proc that fails without setting it cannot
         * leave an indeterminate value behind. */
        dword_t written = 0;
        dword_t ret = cache->write_proc(context, entry->data, entry->address * cache->block_size, cache->block_size, &written);

        /* Both conditions are required: a short write, or a failed write, must
         * leave the block dirty. */
        if ((ret == ERR_SUCCESS) && (written == cache->block_size)) entry->dirty = FALSE;
    }

    flush_cache_internal(cache, node->left, context);
    flush_cache_internal(cache, node->right, context);
}

/*
 * Frees every cache entry in the subtree rooted at `node`.
 *
 * Both children are released before their parent, so the parent's child
 * links are still intact when they are followed. Entries go back to
 * evictable_heap, the heap read_cache and write_cache allocate them from.
 * Dirty entries are freed without being written back.
 */
static void cleanup_cache_internal(avl_node_t *node)
{
    if (node != NULL)
    {
        avl_node_t *left_child = node->left;
        avl_node_t *right_child = node->right;

        cleanup_cache_internal(left_child);
        cleanup_cache_internal(right_child);

        cache_entry_t *victim = CONTAINER_OF(node, cache_entry_t, node);
        heap_free(&evictable_heap, victim);
    }
}

/*
 * AVL key comparator for cache entries.
 *
 * Both keys point to qword_t block addresses. Returns -1, 0 or 1 when the
 * first address is respectively less than, equal to, or greater than the
 * second.
 */
static int compare_address(const void *key1, const void *key2)
{
    const qword_t *lhs = (const qword_t*)key1;
    const qword_t *rhs = (const qword_t*)key2;

    /* Branch-free three-way compare: exactly one of the terms is 1 unless equal. */
    return (*lhs > *rhs) - (*lhs < *rhs);
}

/*
 * Prepares `cache` for use.
 *
 * flags      - cache behaviour flags (e.g. CACHE_WRITE_THROUGH, tested by write_cache)
 * block_size - size in bytes of one cached block
 * read_proc  - callback used to fill a block on a cache miss
 * write_proc - callback used to write a block back to the device
 *
 * The cache starts enabled, with an initialised lock and an empty entry tree
 * keyed by block address.
 */
void init_cache(cache_descriptor_t *cache,
                dword_t flags,
                dword_t block_size,
                read_write_buffer_proc_t read_proc,
                read_write_buffer_proc_t write_proc)
{
    /* Configuration supplied by the caller. */
    cache->flags = flags;
    cache->block_size = block_size;
    cache->read_proc = read_proc;
    cache->write_proc = write_proc;

    /* Runtime state. */
    cache->enabled = TRUE;
    lock_init(&cache->lock);
    AVL_TREE_INIT(&cache->entries, cache_entry_t, node, address, compare_address);
}

/*
 * Releases every entry held by `cache` and leaves its entry tree empty.
 *
 * Dirty entries are discarded without being written back; flush the cache
 * first if their contents must reach the device. No lock is taken here.
 */
void cleanup_cache(cache_descriptor_t *cache)
{
    avl_node_t *old_root = cache->entries.root;

    cache->entries.root = NULL;
    cleanup_cache_internal(old_root);
}

/*
 * Reads `length` bytes starting at byte `offset` of the cached device into
 * `buffer`.
 *
 * Blocks already in the cache are copied straight from memory. A missing
 * block is allocated from evictable_heap, filled by cache->read_proc and
 * inserted into the entry tree.
 *
 * The lock starts out shared. On the first miss it is released and
 * re-acquired exclusively, and the lookup is repeated, since another thread
 * may have inserted the block while the lock was dropped.
 *
 * On return *bytes_read holds the number of bytes copied into `buffer`. This
 * can be less than `length` if an entry could not be allocated or read.
 *
 * Returns ERR_SUCCESS, ERR_NOMEMORY if a cache entry could not be allocated,
 * or the error returned by cache->read_proc.
 */
dword_t read_cache(cache_descriptor_t *cache, void *context, byte_t *buffer, qword_t offset, dword_t length, dword_t *bytes_read)
{
    dword_t ret = ERR_SUCCESS;
    qword_t i;
    bool_t exclusive = FALSE;

    *bytes_read = 0;

    /* An empty read is a no-op. It must be caught here: with offset == 0,
     * (offset + length - 1) would wrap to the maximum qword_t and the loop
     * below would walk (and allocate) an enormous range of blocks. */
    if (length == 0) return ERR_SUCCESS;

    qword_t first_block = offset / (qword_t)cache->block_size;
    qword_t last_block = (offset + length - 1) / (qword_t)cache->block_size;

    lock_acquire_shared(&cache->lock);

    for (i = first_block; i <= last_block; i++)
    {
        dword_t start_offset = 0, bytes_to_copy = cache->block_size;
        avl_node_t *element = avl_tree_lookup(&cache->entries, &i);

        if (element == NULL && !exclusive)
        {
            /* Upgrade to an exclusive lock so the tree may be modified, then
             * look again: the block may have been inserted in the meantime. */
            lock_release(&cache->lock);
            lock_acquire(&cache->lock);
            exclusive = TRUE;
            element = avl_tree_lookup(&cache->entries, &i);
        }

        if (element == NULL)
        {
            cache_entry_t *new_entry = (cache_entry_t*)heap_alloc(&evictable_heap, sizeof(cache_entry_t) + cache->block_size);
            if (new_entry == NULL)
            {
                /* Out of memory: stop and report how much was read so far. */
                ret = ERR_NOMEMORY;
                break;
            }

            new_entry->address = i;
            new_entry->dirty = FALSE;

            ret = cache->read_proc(context, new_entry->data, i * (qword_t)cache->block_size, cache->block_size, NULL);
            if (ret != ERR_SUCCESS)
            {
                heap_free(&evictable_heap, new_entry);
                break;
            }

            avl_tree_insert(&cache->entries, &new_entry->node);
            element = &new_entry->node;
        }

        cache_entry_t *entry = CONTAINER_OF(element, cache_entry_t, node);

        /* Only the first and last blocks of the range can be partial. */
        if (first_block == last_block)
        {
            start_offset = (dword_t)(offset % (qword_t)cache->block_size);
            bytes_to_copy = length;
        }
        else if (i == first_block)
        {
            start_offset = (dword_t)(offset % (qword_t)cache->block_size);
            bytes_to_copy -= start_offset;
        }
        else if (i == last_block)
        {
            bytes_to_copy = length - *bytes_read;
        }

        memcpy(&buffer[*bytes_read], &entry->data[start_offset], bytes_to_copy);
        *bytes_read += bytes_to_copy;
    }

    lock_release(&cache->lock);
    return ret;
}

/*
 * Writes `length` bytes from `buffer` to the cached device starting at byte
 * `offset`, holding the cache lock exclusively throughout.
 *
 * Blocks not yet cached are first filled by cache->read_proc, so that a
 * partial write keeps the rest of the block intact.
 *
 * With CACHE_WRITE_THROUGH set, each modified block is written to the device
 * immediately. The entry is left dirty if that write fails or is short, and
 * clean if the full block was written. Without the flag, modified entries
 * are simply marked dirty.
 *
 * On return *bytes_written holds the number of bytes copied into the cache.
 *
 * Returns:
 *   - ERR_NOMEMORY if a cache entry could not be allocated;
 *   - otherwise the error from cache->read_proc if a block could not be filled;
 *   - otherwise the first error reported by a write-through write_proc call;
 *   - otherwise ERR_SUCCESS.
 */
dword_t write_cache(cache_descriptor_t *cache, void *context, const byte_t *buffer, qword_t offset, dword_t length, dword_t *bytes_written)
{
    dword_t ret = ERR_SUCCESS;
    dword_t write_through_ret = ERR_SUCCESS;
    qword_t i;

    *bytes_written = 0;

    /* An empty write is a no-op. It must be caught here: with offset == 0,
     * (offset + length - 1) would wrap to the maximum qword_t and the loop
     * below would walk (and allocate) an enormous range of blocks. */
    if (length == 0) return ERR_SUCCESS;

    qword_t first_block = offset / (qword_t)cache->block_size;
    qword_t last_block = (offset + length - 1) / (qword_t)cache->block_size;

    lock_acquire(&cache->lock);

    for (i = first_block; i <= last_block; i++)
    {
        dword_t start_offset = 0, bytes_to_copy = cache->block_size;
        avl_node_t *element = avl_tree_lookup(&cache->entries, &i);

        if (element == NULL)
        {
            cache_entry_t *new_entry = (cache_entry_t*)heap_alloc(&evictable_heap, sizeof(cache_entry_t) + cache->block_size);
            if (new_entry == NULL)
            {
                ret = ERR_NOMEMORY;
                break;
            }

            new_entry->address = i;
            new_entry->dirty = FALSE;

            /* Fill the block from the device so bytes outside the written
             * range keep their on-device contents. */
            ret = cache->read_proc(context, new_entry->data, i * (qword_t)cache->block_size, cache->block_size, NULL);
            if (ret != ERR_SUCCESS)
            {
                heap_free(&evictable_heap, new_entry);
                break;
            }

            avl_tree_insert(&cache->entries, &new_entry->node);
            element = &new_entry->node;
        }

        cache_entry_t *entry = CONTAINER_OF(element, cache_entry_t, node);

        /* Only the first and last blocks of the range can be partial. */
        if (first_block == last_block)
        {
            start_offset = (dword_t)(offset % (qword_t)cache->block_size);
            bytes_to_copy = length;
        }
        else if (i == first_block)
        {
            start_offset = (dword_t)(offset % (qword_t)cache->block_size);
            bytes_to_copy -= start_offset;
        }
        else if (i == last_block)
        {
            bytes_to_copy = length - *bytes_written;
        }

        memcpy(&entry->data[start_offset], &buffer[*bytes_written], bytes_to_copy);
        *bytes_written += bytes_to_copy;

        if (cache->flags & CACHE_WRITE_THROUGH)
        {
            dword_t written = 0;
            dword_t wt_ret = cache->write_proc(context, entry->data, i * (qword_t)cache->block_size, cache->block_size, &written);

            if ((wt_ret == ERR_SUCCESS) && (written == cache->block_size))
            {
                /* The whole block is on the device now. */
                entry->dirty = FALSE;
            }
            else
            {
                /* Keep the data for a later flush. Remember the first error
                 * in a separate variable so the next iteration cannot
                 * overwrite it. */
                entry->dirty = TRUE;
                if ((wt_ret != ERR_SUCCESS) && (write_through_ret == ERR_SUCCESS)) write_through_ret = wt_ret;
            }
        }
        else
        {
            entry->dirty = TRUE;
        }
    }

    lock_release(&cache->lock);

    if (ret == ERR_SUCCESS) ret = write_through_ret;
    return ret;
}

/*
 * Writes every dirty entry of `cache` back through cache->write_proc.
 *
 * The cache lock is held exclusively for the whole walk. The walk traverses
 * the entry tree and updates entries' dirty flags, while read_cache and
 * write_cache insert into that same tree under this lock. Without the lock,
 * a concurrent insertion could rebalance the tree mid-traversal.
 *
 * Write-back errors are not reported; this always returns ERR_SUCCESS.
 */
dword_t flush_cache(cache_descriptor_t *cache, void *context)
{
    lock_acquire(&cache->lock);
    flush_cache_internal(cache, cache->entries.root, context);
    lock_release(&cache->lock);

    return ERR_SUCCESS;
}
